{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "@author: huangyongye <br/>\n",
    "@create_date: 2017-05-17 <br/>\n",
    "在上个例子中，我们已经理解了在 TensorFlow 中如何实现 LSTM，在本例子中我们来实现一下 GRU。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracting MNIST_data/train-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/train-labels-idx1-ubyte.gz\n",
      "Extracting MNIST_data/t10k-images-idx3-ubyte.gz\n",
      "Extracting MNIST_data/t10k-labels-idx1-ubyte.gz\n",
      "(55000, 784)\n"
     ]
    }
   ],
   "source": [
    "# -*- coding:utf-8 -*-\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "from tensorflow.contrib import rnn\n",
    "from tensorflow.examples.tutorials.mnist import input_data\n",
    "import time \n",
    "\n",
    "# Have the session grow its GPU memory allocation on demand.\n",
    "config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))\n",
    "sess = tf.Session(config=config)\n",
    "\n",
    "# Read MNIST from the MNIST_data directory with one-hot labels, then print the\n",
    "# shape of the training images to see how the data is laid out.\n",
    "mnist = input_data.read_data_sets('MNIST_data', one_hot=True)\n",
    "print mnist.train.images.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**一、首先设置好模型用到的各个超参数**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Learning rate handed to the optimizer.\n",
    "lr = 1e-3\n",
    "# Training and evaluation use different batch sizes, so the batch size is fed\n",
    "# through a placeholder rather than fixed as a Python constant; it must be tf.int32.\n",
    "batch_size = tf.placeholder(dtype=tf.int32)\n",
    "\n",
    "# Each image is read one row per time step: 28 rows of 28 pixels.\n",
    "input_size = 28      # features per time step (pixels in one image row)\n",
    "timestep_size = 28   # time steps per sample (rows in one image)\n",
    "hidden_size = 256    # units in each GRU layer\n",
    "layer_num = 3        # number of stacked GRU layers\n",
    "class_num = 10       # output classes (this would be 1 for a regression target)\n",
    "\n",
    "# Flattened input images, one-hot labels, and the dropout keep probability.\n",
    "_X = tf.placeholder(dtype=tf.float32, shape=[None, timestep_size * input_size])\n",
    "y = tf.placeholder(dtype=tf.float32, shape=[None, class_num])\n",
    "keep_prob = tf.placeholder(dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**二、开始搭建 GRU 模型，和 LSTM 模型基本一致**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Restore each flat pixel vector to a sequence of image rows: the recurrent cell\n",
    "# consumes input shaped (batch_size, timestep_size, input_size).\n",
    "X = tf.reshape(_X, [-1, timestep_size, input_size])\n",
    "\n",
    "\n",
    "# As of TF 1.1.0 a multi-layer cell is built by creating one cell per layer\n",
    "# through a factory function.\n",
    "def gru_cell():\n",
    "    \"\"\"Build one GRU layer of hidden_size units with dropout applied to its outputs.\"\"\"\n",
    "    # The cell is given the reuse flag of the variable scope that is current when it is constructed.\n",
    "    cell = rnn.GRUCell(hidden_size, reuse=tf.get_variable_scope().reuse)\n",
    "    return rnn.DropoutWrapper(cell, output_keep_prob=keep_prob)\n",
    "\n",
    "\n",
    "# Stack layer_num GRU layers, each with its own cell object.\n",
    "mgru_cell = rnn.MultiRNNCell([gru_cell() for _ in range(layer_num)], state_is_tuple=True)\n",
    "\n",
    "# Start every layer from an all-zero state.\n",
    "init_state = mgru_cell.zero_state(batch_size, dtype=tf.float32)\n",
    "\n",
    "# Option 1: let dynamic_rnn() unroll the network.\n",
    "#   With time_major=False, outputs has shape [batch_size, timestep_size, hidden_size],\n",
    "#   so the last step is outputs[:, -1, :].\n",
    "#   state is a tuple of layer_num tensors, each [batch_size, hidden_size] (a GRU keeps a\n",
    "#   single state tensor per layer, not an LSTM-style (c, h) pair); state[-1] is the top layer.\n",
    "#   Either way the final feature has shape [batch_size, hidden_size].\n",
    "# outputs, state = tf.nn.dynamic_rnn(mgru_cell, inputs=X, initial_state=init_state, time_major=False)\n",
    "# h_state = state[-1]\n",
    "\n",
    "# Option 2 (used here): unroll by hand to make the mechanics explicit. The stacked\n",
    "# cell is callable, so it can be applied one time step at a time.\n",
    "outputs = list()\n",
    "state = init_state\n",
    "with tf.variable_scope('RNN'):\n",
    "    for timestep in range(timestep_size):\n",
    "        if timestep > 0:\n",
    "            # Variables are created on the first step; later steps must reuse them.\n",
    "            tf.get_variable_scope().reuse_variables()\n",
    "        # state carries the hidden state of every GRU layer on to the next step.\n",
    "        (cell_output, state) = mgru_cell(X[:, timestep, :], state)\n",
    "        outputs.append(cell_output)\n",
    "# outputs is a Python list of timestep_size tensors; the last one feeds the classifier.\n",
    "h_state = outputs[-1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**三、最后设置 loss function 和优化器，展开训练并完成测试**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter1, step 500, training accuracy 0.984375\n",
      "Iter2, step 1000, training accuracy 0.984375\n",
      "Iter3, step 1500, training accuracy 0.976562\n",
      "Iter4, step 2000, training accuracy 1\n",
      "Iter5, step 2500, training accuracy 0.992188\n",
      "Iter6, step 3000, training accuracy 0.992188\n",
      "Iter8, step 3500, training accuracy 0.992188\n",
      "Iter9, step 4000, training accuracy 1\n",
      "test accuracy 0.9899\n",
      "Time cost 244.955\n"
     ]
    }
   ],
   "source": [
    "############################################################################\n",
    "# The remainder mirrors the earlier multi-layer CNN MNIST example; the only\n",
    "# difference is that batch_size must be fed at evaluation time as well.\n",
    "\n",
    "# The GRU yields a [batch_size, hidden_size] feature per sample; a softmax layer on\n",
    "# top turns it into class probabilities.\n",
    "W = tf.Variable(tf.truncated_normal([hidden_size, class_num], stddev=0.1), dtype=tf.float32)\n",
    "bias = tf.Variable(tf.constant(0.1, shape=[class_num]), dtype=tf.float32)\n",
    "logits = tf.matmul(h_state, W) + bias\n",
    "y_pre = tf.nn.softmax(logits)\n",
    "\n",
    "# Loss and evaluation ops. The cross-entropy is computed from the logits: taking\n",
    "# tf.log() of an already-softmaxed probability produces -inf / NaN as soon as a\n",
    "# probability underflows to 0. The result is the mean per-example cross-entropy.\n",
    "cross_entropy = tf.reduce_mean(\n",
    "    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))\n",
    "train_op = tf.train.AdamOptimizer(lr).minimize(cross_entropy)\n",
    "\n",
    "correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(y, 1))\n",
    "accuracy = tf.reduce_mean(tf.cast(correct_prediction, \"float\"))\n",
    "\n",
    "train_steps = 4000    # number of optimisation steps\n",
    "report_every = 500    # print the training accuracy every this many steps\n",
    "_batch_size = 128     # training mini-batch size\n",
    "\n",
    "time0 = time.time()\n",
    "sess.run(tf.global_variables_initializer())\n",
    "for i in range(train_steps):\n",
    "    batch = mnist.train.next_batch(_batch_size)\n",
    "    if (i+1) % report_every == 0:\n",
    "        # Accuracy on the batch about to be trained on, with dropout disabled.\n",
    "        train_accuracy = sess.run(accuracy, feed_dict={\n",
    "            _X: batch[0], y: batch[1], keep_prob: 1.0, batch_size: _batch_size})\n",
    "        # mnist.train.epochs_completed counts the epochs finished so far.\n",
    "        print \"Iter%d, step %d, training accuracy %g\" % (mnist.train.epochs_completed, (i+1), train_accuracy)\n",
    "    sess.run(train_op, feed_dict={_X: batch[0], y: batch[1], keep_prob: 0.5, batch_size: _batch_size})\n",
    "\n",
    "# Accuracy on the full test set, evaluated as a single batch.\n",
    "print \"test accuracy %g\" % sess.run(accuracy, feed_dict={\n",
    "    _X: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0, batch_size: mnist.test.images.shape[0]})\n",
    "print 'Time cost %g' % (time.time() - time0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "和前面 lstm 的例子比较一下，二者的准确率并没有太大的差别。在同样迭代 4000 次的条件下：\n",
    "- 运行时间为： lstm 248.446s　　　gru 244.955s\n",
    "- 占用显存为： lstm 727m　　　　　gru 471m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 四、可视化看看 GRU 是怎么做分类的"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "毕竟 GRU 更多的是用来处理时序相关的问题，要么是文本，要么是序列预测之类的，所以很难像 CNNs 一样非常直观地看到每一层中特征的变化。在这里，我想通过可视化的方式来帮助大家理解 GRU 是怎样一步一步地把图片正确分类的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5, 784) (5, 10)\n",
      "_outputs.shape = (28, 5, 256)\n",
      "arr_state.shape = (3, 5, 256)\n"
     ]
    }
   ],
   "source": [
    "# Fetch the per-step outputs and the final state of the hand-unrolled network for a\n",
    "# small test batch and inspect their shapes.\n",
    "_batch_size = 5\n",
    "X_batch, y_batch = mnist.test.next_batch(_batch_size)\n",
    "print X_batch.shape, y_batch.shape\n",
    "# sess.run mirrors the fetch structure: a list of timestep_size arrays and a tuple of\n",
    "# layer_num arrays. Unpack it directly; the two parts have different lengths and do not\n",
    "# form a regular array.\n",
    "_outputs, _state = sess.run([outputs, state], feed_dict={\n",
    "            _X: X_batch, y: y_batch, keep_prob: 1.0, batch_size: _batch_size})\n",
    "print '_outputs.shape =', np.asarray(_outputs).shape\n",
    "print 'arr_state.shape =', np.asarray(_state).shape\n",
    "# As printed: outputs stacks to [timestep_size, batch_size, hidden_size]\n",
    "# and the GRU state to [layer_num, batch_size, hidden_size]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(5, 784) (5, 10)\n",
      "_outputs.shape = (28, 5, 256)\n",
      "arr_state.shape = (3, 5, 256)\n"
     ]
    }
   ],
   "source": [
    "# Repeat the shape check on the next five test samples.\n",
    "_batch_size = 5\n",
    "X_batch, y_batch = mnist.test.next_batch(_batch_size)\n",
    "print X_batch.shape, y_batch.shape\n",
    "shape_feed = {_X: X_batch, y: y_batch, keep_prob: 1.0, batch_size: _batch_size}\n",
    "_outputs, _state = sess.run([outputs, state], feed_dict=shape_feed)\n",
    "print '_outputs.shape =', np.shape(_outputs)\n",
    "print 'arr_state.shape =', np.shape(_state)\n",
    "# outputs stacks to [timestep_size, batch_size, hidden_size];\n",
    "# the GRU state stacks to [layer_num, batch_size, hidden_size]."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "看下面我找了一个字符 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]\n"
     ]
    }
   ],
   "source": [
    "# One-hot label of training sample 6.\n",
    "print mnist.train.labels[6]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们先来看看这个字符样子,上半部分还挺像 2 来的"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAD8CAYAAAC4nHJkAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAADdhJREFUeJzt3X+MFPUZx/HPA4VgLEYterncgVQlpwYjmos2hjSYloYa\nDDQhpPiHNDaef1isSWM09o+SlJqmsTX1H5MjEFFboUFR0pgSJE1pEyQehopg8Se1R5AfUsKPGBDv\n6R87NCfefHdvd3Znj+f9Si63O8/OzpPJfW5mdnbma+4uAPGMK7sBAOUg/EBQhB8IivADQRF+ICjC\nDwRF+IGgCD8QFOEHgvpaKxdmZnydEGgyd7daXtfQlt/M5pnZXjN738webeS9ALSW1fvdfjMbL+ld\nSXMlDUp6Q9ISd9+TmIctP9Bkrdjy3yrpfXf/0N3PSForaUED7weghRoJf5ek/wx7PphN+xIz6zOz\nATMbaGBZAArW9A/83L1fUr/Ebj/QThrZ8u+XNHXY8+5sGoAxoJHwvyFphpl908wmSvqhpI3FtAWg\n2ere7Xf3s2b2E0mbJI2XtNrddxfWGYCmqvtUX10L45gfaLqWfMkHwNhF+IGgCD8QFOEHgiL8QFCE\nHwiK8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii/EBQ\nhB8IivADQRF+ICjCDwRF+IGgCD8QFOEHgiL8QFB1D9EtSWa2T9IJSV9IOuvuvUU0hS+bOHFisj5l\nypTc2qJFi5LzzpgxI1nv6upK1seNS28/FixYkFtrdITodevWJesrVqzIre3ezWjyDYU/c4e7Hyng\nfQC0ELv9QFCNht8lvWZmO8ysr4iGALRGo7v9s919v5ldKWmzmf3L3bcOf0H2T4F/DECbaWjL7+77\ns9+HJG2QdOsIr+l3914+DATaS93hN7OLzWzyuceSvifp7aIaA9Bcjez2d0jaYGbn3ueP7v6XQroC\n0HTW6LnWUS3MrHULayPZP8hcixcvTtaXL1+erPf09Iy2pZodPHgwWR8aGkrWz5w5k1s7ceJEct6r\nrroqWZ88eXKy/umnn+bWpk2blpz3s88+S9bbmbun/+AynOoDgiL8QFCEHwiK8ANBEX4gKMIPBMWp\nvhaYMGFCsn769Olk/ezZs8n6pk2bcmvr169Pzrt3795kfWBgIFmv1lsjuru7k/U1a9Yk63fccUdu\n7ZJLLknOe/LkyWS9nXGqD0AS4QeCIvxAUIQfCIrwA0ERfiAowg8EVcTde1FFtXPh9957b7K+bdu2\nZL3aufqxanBwMFk/depUizq5MLHlB4Ii/EBQhB8IivADQRF+ICjCDwRF+IGguJ4fbWvmzJnJ+tat\nW5P11G3Hb7rppuS8qVuOtzuu5weQRPiBoAg/EBThB4Ii/EBQhB8IivADQVW9nt/MVkuaL+mQu8/M\npl0uaZ2k6ZL2SVrs7v9tXpu4EKXuqy9Jzz//fLJ+6aWXJusLFy7MrY3l8/hFqWXL/4ykeedNe1TS\nFnefIWlL9hzAGFI1/O6+VdLR8yYvkHRuuJQ1kvL/xQJoS/Ue83e4+4Hs8SeSOgrqB0CLNHwPP3f3\n1Hf2zaxPUl+jywFQrHq3/AfNrFOSst+H8l7o7v3u3uvuvXUuC0AT1Bv+jZKWZo+XSnqlmHYAtErV\n8JvZC5K2Seoxs0Ez+7GkX0uaa2bvSfpu9hzAGML1/Be48ePHJ+tXXnllsn7NNdck652dncn63Xff\nnVu76667kvMeO3YsWZ8/f36yvn379txaK//uW43r+QEkEX4gKMIPBEX4gaAIPxAU4QeCYojuC9zm\nzZuT9Tlz5rSmkREcPnw4Wb/++uuT9aNHz7/eDKPBlh8IivADQRF+ICjCDwRF+IGgCD8QFOEHguI8\n/wVuy5YtyXpXV1eyPmnSpGTdLH31aH
d3d27tiiuuSM47bdq0ZJ3z/I1hyw8ERfiBoAg/EBThB4Ii\n/EBQhB8IivADQXHrbjRk3Lj09uPJJ5/MrS1btiw575IlS5L1devWJetRcetuAEmEHwiK8ANBEX4g\nKMIPBEX4gaAIPxBU1ev5zWy1pPmSDrn7zGzackn3STp34/XH3P3VZjWJ9jU0NJSsr1ixIrdW7Tx/\ntfv2ozG1bPmfkTRvhOlPuvus7IfgA2NM1fC7+1ZJ3DIFuMA0csy/zMzeMrPVZnZZYR0BaIl6w/+0\npKslzZJ0QNJv815oZn1mNmBmA3UuC0AT1BV+dz/o7l+4+5CklZJuTby239173b233iYBFK+u8JtZ\n57CnP5D0djHtAGiVWk71vSBpjqQpZjYo6ReS5pjZLEkuaZ+k+5vYI4AmqBp+dx/poupVTegFF6B5\n80Y6S1ybXbt2FdgJzsc3/ICgCD8QFOEHgiL8QFCEHwiK8ANBMUQ3mmrRokV1z/vBBx8U2AnOx5Yf\nCIrwA0ERfiAowg8ERfiBoAg/EBThB4JiiG40pKenJ1nfs2dPbu3w4cO5NUm67rrrkvVjx44l61Ex\nRDeAJMIPBEX4gaAIPxAU4QeCIvxAUIQfCIrr+ZE0efLkZP25555L1s3yTzmvXLkyOS/n8ZuLLT8Q\nFOEHgiL8QFCEHwiK8ANBEX4gKMIPBFX1PL+ZTZX0rKQOSS6p391/b2aXS1onabqkfZIWu/t/m9cq\nmuGWW25J1l9++eVkvbu7O1lfu3Ztbu2JJ55IzovmqmXLf1bSz9z9BknfkvSAmd0g6VFJW9x9hqQt\n2XMAY0TV8Lv7AXd/M3t8QtI7krokLZC0JnvZGkkLm9UkgOKN6pjfzKZLulnSdkkd7n4gK32iymEB\ngDGi5u/2m9nXJb0o6SF3Pz78O9vu7nn35zOzPkl9jTYKoFg1bfnNbIIqwf+Du7+UTT5oZp1ZvVPS\noZHmdfd+d+91994iGgZQjKrht8omfpWkd9z9d8NKGyUtzR4vlfRK8e0BaJaqt+42s9mS/i5pl6Sh\nbPJjqhz3/0nSNEn/VuVU39Eq78Wtu5vg2muvza098sgjyXnvueeeZH3ChAnJerXLch9++OHc2vHj\nx5Pzoj613rq76jG/u/9DUt6bfWc0TQFoH3zDDwiK8ANBEX4gKMIPBEX4gaAIPxDUmLp192233ZZb\nq3a+udr57J07d9bVkyT19qa/vDhx4sRk/cEHH0zWq112m7qsdtKkScl5P/roo2R99erVyfrjjz+e\nrLdyCHiMDlt+ICjCDwRF+IGgCD8QFOEHgiL8QFCEHwhqTJ3n7+npya3NnDkzOe+qVauS9ddffz1Z\nv/HGG3Nrt99+e3LeceMa+x/78ccfJ+uvvvpqbu2pp55Kzrtjx45k/dSpU8k6xi62/EBQhB8IivAD\nQRF+ICjCDwRF+IGgCD8QVNX79he6sAbv23/RRRfl1qoNJT137txGFq0jR47k1tavX5+cd9u2bcn6\n559/nqxv2LAhWT99+nSyjlhqvW8/W34gKMIPBEX4gaAIPxAU4QeCIvxAUIQfCKrqeX4zmyrpWUkd\nklxSv7v/3syWS7pP0uHspY+5e/6F5Wr8PD+A6mo9z19L+Dsldbr7m2Y2WdIOSQslLZZ00t2fqLUp\nwg80X63hr3onH3c/IOlA9viEmb0jqaux9gCUbVTH/GY2XdLNkrZnk5aZ2VtmttrMLsuZp8/MBsxs\noKFOARSq5u/2m9nXJf1N0q/c/SUz65B0RJXPAX6pyqHBvVXeg91+oMkKO+aXJDObIOnPkja5++9G\nqE+X9Gd3T95Fk/ADzVfYhT1mZpJWSXpnePCzDwLP+YGkt0fbJIDy1PJp/2xJf5e0S9JQNvkxSUsk\nzVJlt3+fpPuzDwdT78WWH2iyQnf7i0L4gebjen4ASYQfCIrwA0ERfiAowg8ERfiBoAg/EBThB4Ii\n/E
BQhB8IivADQRF+ICjCDwRF+IGgqt7As2BHJP172PMp2bR21K69tWtfEr3Vq8jerqr1hS29nv8r\nCzcbcPfe0hpIaNfe2rUvid7qVVZv7PYDQRF+IKiyw99f8vJT2rW3du1Lord6ldJbqcf8AMpT9pYf\nQElKCb+ZzTOzvWb2vpk9WkYPecxsn5ntMrOdZQ8xlg2DdsjM3h427XIz22xm72W/RxwmraTelpvZ\n/mzd7TSzO0vqbaqZ/dXM9pjZbjP7aTa91HWX6KuU9dby3X4zGy/pXUlzJQ1KekPSEnff09JGcpjZ\nPkm97l76OWEz+7akk5KePTcakpn9RtJRd/919o/zMnd/pE16W65RjtzcpN7yRpb+kUpcd0WOeF2E\nMrb8t0p6390/dPczktZKWlBCH23P3bdKOnre5AWS1mSP16jyx9NyOb21BXc/4O5vZo9PSDo3snSp\n6y7RVynKCH+XpP8Mez6o9hry2yW9ZmY7zKyv7GZG0DFsZKRPJHWU2cwIqo7c3ErnjSzdNuuunhGv\ni8YHfl81291nSfq+pAey3du25JVjtnY6XfO0pKtVGcbtgKTfltlMNrL0i5Iecvfjw2tlrrsR+ipl\nvZUR/v2Spg573p1Nawvuvj/7fUjSBlUOU9rJwXODpGa/D5Xcz/+5+0F3/8LdhyStVInrLhtZ+kVJ\nf3D3l7LJpa+7kfoqa72VEf43JM0ws2+a2URJP5S0sYQ+vsLMLs4+iJGZXSzpe2q/0Yc3SlqaPV4q\n6ZUSe/mSdhm5OW9kaZW87tpuxGt3b/mPpDtV+cT/A0k/L6OHnL6ulvTP7Gd32b1JekGV3cDPVfls\n5MeSviFpi6T3JL0m6fI26u05VUZzfkuVoHWW1NtsVXbp35K0M/u5s+x1l+irlPXGN/yAoPjADwiK\n8ANBEX4gKMIPBEX4gaAIPxAU4QeCIvxAUP8Dq9eBRTuewxMAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f146e575ed0>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Take training sample 6 and view its 784 pixels as a 28 x 28 grayscale image.\n",
    "X3 = mnist.train.images[6]\n",
    "img3 = np.reshape(X3, (28, 28))\n",
    "plt.imshow(img3, cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们看看在分类的时候，一行一行地输入，分为各个类别的概率会是什么样子的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(28, 1, 256)\n",
      "(28, 256)\n"
     ]
    }
   ],
   "source": [
    "# Run the single sample through the network and collect the GRU output at every time step.\n",
    "X3 = X3.reshape([-1, 784])\n",
    "# Use the label of the same sample (index 6) that X3 was taken from.\n",
    "y_batch = mnist.train.labels[6].reshape([-1, class_num])\n",
    "\n",
    "X3_outputs = np.array(sess.run(outputs, feed_dict={\n",
    "            _X: X3, y: y_batch, keep_prob: 1.0, batch_size: 1}))\n",
    "print X3_outputs.shape\n",
    "# Drop the singleton batch axis: one hidden_size-dimensional output per time step.\n",
    "X3_outputs = X3_outputs.reshape([timestep_size, hidden_size])\n",
    "print X3_outputs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD8CAYAAABw1c+bAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAABaNJREFUeJzt3UFu2zgAQNF4TpJj5mI5n7oYdFEgsB1LtEj997ZFWoWS\nPmmZdm/btn0A0PHf2QcAwHsJP0CM8APECD9AjPADxAg/QIzwA8QIP0CM8APECD9AjPADxAj/m3x+\nfW+fX98/fjHSvT+b1WrHuxJjO86K99oIS4f/jJi6aOZ077w4Z/CvpcP/qkeTwquhGBWYM8I128po\ntuOBlU0ffjc7wLGmDz/jzLaKnu14SjwqaxF+DiciMDfhJ+2siegqE6BXaY/NOD7CD4PMeMMzl7Mm\nTuEHiBF+uAiPXXiW8ANMaORELvywkD0xePergZHherRzzCuf+4Qf4AlXmkyEHyBG+AFihB9gQXse\nPQk/wMdab5zvJfwAMcIPECP8ADHCDxAj/AAxwg8QI/wAMcIPEHPbtqU+dwDATlb8ADHCDxAj/AAx\nwg8QI/wAMZcNv/93c07OScdq9+Bqx7vHZcMPwM+EHyBG+AFihB8gRvgBYoSfS6jsxoAjCP+BxIcz\nlbYjso/wv4mbEq5rtXtb+AFihB8gRvhJW+0lOhxB+H8gBsCVTR9+EQY41vThv+cqO2Wu8nvMytjC\nv5YOPx0mRziO8APECP8CzljpWmHDdQk/LzEpwLqEHyAmGf49jzGutNL1OAeakuEHnndvcWDxMM7I\nsRV+Drfn1ZSIwHjCDwvxmJK/9pxP4YcTiPB6rnTOhB8GWSkUKx3rI7M9MpzxPZLbtk0zPgC8gRU/\nQIzwA8QIP0CM8APECD9AjPD/0kzbxK5mtm14I1V+zzOUrqNXCT9AjPADxAg/QIzwA8QIP0CM8APE\nCD9AjPADxAg/QIzwA8QIP0CM8APECD9AjPADxAg/RPiqYv5Kht/3dQNlyfDzmMkRrkv4AWKEHyBG\n+AFihB8gRvg5nDeFqVntmhd+gJilw2/LIfAbmvG/6cP/6knac4JH/Ny943l0rDNeqGeM7atjdKWx\n3XMd7fk33308Zxh1H87Yhdu2LXNeADjA9Ct+AI4l/AAxwg8QI/wAMcIPECP8v7TS9rTVrLb9bzXG\ndpzVxlb4AWKEHyBG+AFihB8gRvgBYoQfIEb4AWKEH+Cj9TkS4QeIEX6AGOEHiBF+gJilw196Mwbg\nKEuHfyUmKWAWwg8QI/wczisbmJvwA8Qkw+95O6ty7XKEZPgByoQfIEb4uQSPP+B5wg9wkrPesxF+\ngAXtmTAuG/49M+m7f+6svxdj+4w994PxndNt25wXgJLLrvgB+JnwA8QIP0CM8APECD9AjPADxCTD\nb3/x9VTOp2t3nNLYJsMPUCb8ADHCDxAj/AAxwg8QI/wAMcIPECP8ADHCz+EqH4KBVQk/cJeJ/HqE\nHyBG+N+k9D0gwNyEHyBG+H/Jqh1YnfADxAg/QIzwA8QIP0CM8APECD9Mxs4xRhP+BQjBOMaWIuEH\niBF+XmKlzCO+pmRews80hGI9ztk+Z43fbducM4ASK36AGOEHiBF+gBjhB4gRfoAY4QeISYbf3uOx\njO04rt1xSmObDD9AmfADxAg/QIzwA8QIP0CM8APECD9AjPADxAg/QIzwA8QIP0CM8APECD9AjPAD\nxAg/QIzwA8QIP0CM8APECD9AjPADxAg/QIzwA8QIP0CM8APE3LZtO/sYAHgjK36AGOEHiBF+gBjh\nB4gRfoAY4QeISYb/8+t7+/z6to91EGM7jmt3nNLYJsMPUCb8ADHCDxAj/AAxwg8QI/wAMcIPECP8\nADHCDxAj/AAxwg8QI/wAMcIPECP8ADHCDxAj/AAxwg8QI/wAMcIPECP8ADHCDxAj/AAxwg8QI/wA\nMbdt284+BgDeyIofIEb4AWKEHyBG+A
FihB8gRvgBYoQfIEb4AWKEHyBG+AFihB8gRvgBYoQfIEb4\nAWKEHyBG+AFihB8gRvgBYoQfIEb4AWKEHyBG+AFihB8g5g+6AjEjRuOySgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f146e50a090>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fetch the trained softmax weights once; reading variables needs no feed_dict.\n",
    "h_W, h_bias = sess.run([W, bias])\n",
    "h_bias = h_bias.reshape([-1, class_num])\n",
    "\n",
    "# Class probabilities after each time step, computed in NumPy in a single shot.\n",
    "# Building tf.matmul / tf.nn.softmax ops inside the loop would add new nodes to the\n",
    "# graph on every iteration (and again on every re-run of the cell).\n",
    "step_logits = np.dot(X3_outputs, h_W) + h_bias\n",
    "step_logits -= step_logits.max(axis=1, keepdims=True)  # stabilise exp()\n",
    "step_probs = np.exp(step_logits)\n",
    "step_probs /= step_probs.sum(axis=1, keepdims=True)\n",
    "\n",
    "# One bar chart per time step, laid out 4 per row over 7 rows.\n",
    "bar_index = range(class_num)\n",
    "for i in range(step_probs.shape[0]):\n",
    "    plt.subplot(7, 4, i+1)\n",
    "    plt.bar(bar_index, step_probs[i], width=0.2, align='center')\n",
    "    plt.axis('off')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在上面的图中，为了更清楚地看到线条的变化，我把坐标轴都去掉了，每一行显示了 4 个图，共有 7 行，表示了一行一行读取过程中，模型对字符的识别。可以看到，在只看到前面的几行像素时，模型根本认不出来是什么字符，随着看到的像素越来越多，最后就基本确定了它是字符 3。"
   ]
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
